
feat(srt): support prefill and generate with input_embeds #2082

Closed · XuehaiPan wants to merge 5 commits from the generation-input-embeds branch

Conversation

XuehaiPan (Contributor) commented on Nov 18, 2024

Motivation

Resolves #745

Modifications

As per the commit messages.
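For orientation, here is a rough sketch of the kind of client call this feature enables: embed the prompt yourself, then send the embeddings instead of token ids. The endpoint and field names below are assumptions for illustration, not the exact schema added by these commits:

```python
import requests
import torch

# Stand-in for embeddings produced by a model's embedding layer (or any
# custom soft prompt): shape [seq_len, hidden_size].
input_embeds = torch.randn(5, 4096)

# Tensors are sent as nested Python lists so the request stays
# JSON/Pydantic-serializable (see the io_struct comment further down).
resp = requests.post(
    "http://localhost:30000/generate",  # assumed local sglang server
    json={
        "input_embeds": input_embeds.tolist(),
        "sampling_params": {"max_new_tokens": 32},
    },
)
print(resp.json())
```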

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

XuehaiPan changed the title from "feat(srt/io_struct): support prefill and generate with input_embeds" to "feat(srt): support prefill and generate with input_embeds" on Nov 18, 2024.
XuehaiPan force-pushed the generation-input-embeds branch 2 times, most recently from 857750a to 8058e22, on Nov 18, 2024.
merrymercy (Contributor) commented:

Thanks for the contribution. There is a recent related PR, #2052. Could you take a look at it?

merrymercy mentioned this pull request on Nov 18, 2024.
XuehaiPan force-pushed the generation-input-embeds branch from 976f2c0 to 98331c0 on Nov 19, 2024.
merrymercy (Contributor) left a review:
Thanks for the contribution! I left a few comments.
Can you add a test case for llama and llava?

[Review comment on .pre-commit-config.yaml — outdated, resolved.]
```python
from enum import Enum
from typing import Dict, List, Optional, Union

from sglang.srt.managers.schedule_batch import BaseFinishReason
from sglang.srt.sampling.sampling_params import SamplingParams

# Use sequence instead of Tensor here because Pydantic serializes Python objects
```
merrymercy (Contributor) commented:

sequence or list?
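For readers following along: the question is only about the comment's wording, but the underlying point is that the field is declared as nested lists of floats rather than torch.Tensor, because Pydantic can (de)serialize plain Python containers but not tensors. A minimal sketch of that idea (hypothetical model class, Pydantic v2):

```python
from typing import List, Optional

import torch
from pydantic import BaseModel


class GenerateReqInput(BaseModel):
    # Nested lists of floats, shape [seq_len][hidden_size]:
    # JSON-serializable, unlike a torch.Tensor field.
    input_embeds: Optional[List[List[float]]] = None


req = GenerateReqInput(input_embeds=torch.randn(2, 4).tolist())
print(req.model_dump_json())  # round-trips cleanly through JSON
```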

Comment on lines +219 to +226:

```python
if sys.version_info >= (3, 10):
    _: dataclasses.KW_ONLY
```
merrymercy (Contributor) commented:

What is this used for?
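For background (general Python, not specific to this PR): `dataclasses.KW_ONLY` is a Python 3.10+ sentinel type; every field declared after a pseudo-field annotated with it becomes keyword-only in the generated `__init__`, hence the version guard. Field names below are illustrative:

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass
class Req:
    rid: str
    _: dataclasses.KW_ONLY  # sentinel: every field below is keyword-only
    input_embeds: Optional[list] = None


Req("r0", input_embeds=[[0.1, 0.2]])  # OK
# Req("r0", [[0.1, 0.2]])            # TypeError: too many positional arguments
```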

```diff
@@ -430,6 +435,9 @@ def __repr__(self):
 class ScheduleBatch:
     """Store all information of a batch."""

+    if sys.version_info >= (3, 10):
```
merrymercy (Contributor) commented:

Is it possible to get rid of this?

```diff
@@ -876,7 +902,7 @@ def check_for_jump_forward(self, pad_input_ids_func):
                 jump_forward_reqs.append(req)
                 keep_indices.remove(i)

-        self.filter_batch(keep_indices=list(keep_indices))
+        self.filter_batch(keep_indices=sorted(keep_indices))
```
merrymercy (Contributor) commented:

Why is sorted better here?
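A plausible answer, inferred from the surrounding code rather than stated in the thread: `keep_indices` is a set at this point, and iterating a set yields an arbitrary, implementation-defined order, while `sorted()` guarantees ascending indices, i.e. the original batch order:

```python
keep_indices = {2, 0, 10**6}

# Set iteration order is an implementation detail and may scramble rows.
print(list(keep_indices))    # e.g. [1000000, 0, 2] -- not guaranteed

# sorted() is deterministic and preserves the original batch ordering.
print(sorted(keep_indices))  # [0, 2, 1000000]
```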

```python
(
    logits_output,
    next_token_ids,
    next_token_embeds,
```
merrymercy (Contributor) commented:

Why do you need next_token_embeds? I think after the first prefill, we can use token ids and do not need to take embedding inputs anymore.

```python
(
    logits_output,
    next_token_ids,
    next_token_embeds,
```
merrymercy (Contributor) commented:

I think next_token_embeds is probably not necessary here; it makes things much more complicated. Some of the handling here is also incorrect, since the copies of these tensors need to be handled properly. Ideally, we can get rid of next_token_embeds and leave this file unchanged.
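To illustrate the reviewer's point: once prefill has consumed the user-supplied embeddings, every subsequent decode step can recover embeddings from the sampled token ids through the model's own embedding table, so nothing embedding-shaped needs to flow through the scheduler. A self-contained sketch (sizes and names are illustrative):

```python
import torch
import torch.nn as nn

# Stand-in for the model's token embedding table.
embed_tokens = nn.Embedding(num_embeddings=32000, embedding_dim=4096)


def decode_step_inputs(next_token_ids: torch.Tensor) -> torch.Tensor:
    # Carries the same information as a hypothetical next_token_embeds,
    # recomputed from ids alone -- the scheduler only tracks token ids.
    return embed_tokens(next_token_ids)


print(decode_step_inputs(torch.tensor([17, 42])).shape)  # torch.Size([2, 4096])
```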

```diff
@@ -211,6 +218,11 @@ def init_new(
         forward_mode=batch.forward_mode,
         batch_size=len(batch.seq_lens),
         input_ids=batch.input_ids,
+        input_embeds=(
+            batch.input_embeds.clone().detach().to(device)
```
merrymercy (Contributor) commented:

Can we get rid of this extra copy?
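One possible answer, assuming the source tensor is not mutated afterwards: `.to(device)` already copies whenever the tensor lives on a different device (and is a no-op otherwise), which makes the `clone()` a second, redundant copy:

```python
import torch

batch_input_embeds = torch.randn(5, 4096)  # stand-in for batch.input_embeds
device = "cuda" if torch.cuda.is_available() else "cpu"

# clone().detach().to(device) can materialize the data twice. If the source
# is not mutated later, detach() creates no copy and to() copies at most once.
input_embeds = batch_input_embeds.detach().to(device, non_blocking=True)
```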

```python
def forward_decode(self, forward_batch: ForwardBatch):
    if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
        return self.cuda_graph_runner.replay(forward_batch)

    forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
    self.attn_backend.init_forward_metadata(forward_batch)

    if forward_batch.input_embeds is not None:
```
merrymercy (Contributor) commented:

We probably only need this input_embeds for prefill.

merrymercy (Contributor) commented:

I feel #2052 is probably a cleaner solution.

XuehaiPan marked this pull request as draft on Nov 22, 2024.
XuehaiPan force-pushed the generation-input-embeds branch from 98331c0 to 62e3104 on Nov 22, 2024.
XuehaiPan force-pushed the generation-input-embeds branch from 62e3104 to 4e28940 on Nov 23, 2024.
merrymercy (Contributor) commented:

Will close this one in favor of #2052.

merrymercy closed this on Nov 25, 2024.
Linked issue: #745 — [Feature] Generation Inputs: input_embeds